/*
 * SPDX-FileCopyrightText: 2020 Stalwart Labs LLC <hello@stalw.art>
 *
 * SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-SEL
 */

use common::{
    config::smtp::queue::{QueueExpiry, QueueName},
    expr::{self, functions::ResolveVariable, *},
};
use compact_str::ToCompactString;
use smtp_proto::Response;
use std::{
    fmt::Display,
    net::{IpAddr, Ipv4Addr},
    time::{Duration, Instant, SystemTime},
};
use store::write::now;
use types::blob_hash::BlobHash;
use utils::DomainPart;

pub mod dsn;
pub mod manager;
pub mod quota;
pub mod spool;
pub mod throttle;

pub type QueueId = u64;

/// A value of type `T` paired with the time at which it becomes due.
///
/// The comparison traits implemented for this type in this file look only at
/// `due`; `inner` never takes part in equality or ordering.
#[derive(Debug, Clone, rkyv::Serialize, rkyv::Deserialize, rkyv::Archive, serde::Deserialize)]
pub struct Schedule<T> {
    /// Due time, in the same unit as the value returned by `store::write::now()`.
    // NOTE(review): presumably seconds since the UNIX epoch — confirm.
    pub due: u64,
    /// Payload carried alongside the due time.
    pub inner: T,
}

/// Small, copyable record naming a queued message by id together with its
/// due time and queue name.
#[derive(Debug, Clone, Copy)]
pub struct QueuedMessage {
    /// Due time of the entry.
    // NOTE(review): presumably the same clock as `Schedule::due` — confirm.
    pub due: u64,
    /// Identifier of the queued message.
    pub queue_id: QueueId,
    /// Name of the queue associated with this entry.
    pub queue_name: QueueName,
}

/// Origin of a message. The variants mirror the names of the `FROM_*` flag
/// constants declared in this file.
#[derive(Debug, Clone, Copy)]
pub enum MessageSource {
    Authenticated,
    // NOTE(review): the bool presumably records whether DMARC passed
    // (compare `FROM_UNAUTHENTICATED_DMARC`) — confirm against callers.
    Unauthenticated(bool),
    Dsn,
    Report,
    Autogenerated,
}

/// A queued message: envelope data, receipt metadata and quota bookkeeping.
#[derive(rkyv::Serialize, rkyv::Deserialize, rkyv::Archive, Debug, Clone, PartialEq, Eq)]
pub struct Message {
    /// Creation time; compared against `store::write::now()` when computing
    /// the queue age and the remaining TTL of a recipient.
    pub created: u64,
    // NOTE(review): presumably the hash of the stored message contents — confirm.
    pub blob_hash: BlobHash,

    /// Envelope sender; exposed to expressions as the sender variable.
    pub return_path: String,
    /// Recipients of this message, each with its own delivery state.
    pub recipients: Vec<Recipient>,

    /// IP address the message was received from.
    pub received_from_ip: IpAddr,
    /// Port the message was received on.
    pub received_via_port: u16,

    /// Bit flags; the `FROM_*` constants are tested against this field.
    pub flags: u64,
    // NOTE(review): presumably the DSN envelope id (ENVID) — confirm.
    pub env_id: Option<String>,
    /// Priority; exposed to expressions as the priority variable.
    pub priority: i16,

    /// Message size.
    // NOTE(review): presumably in bytes — confirm.
    pub size: u64,
    /// Quota entries associated with this message.
    pub quota_keys: Vec<QuotaKey>,
}

/// A `Message` bundled with queue-side bookkeeping that is not part of the
/// archived `Message` itself (this type carries no rkyv derives).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MessageWrapper {
    /// Identifier of the message in the queue.
    pub queue_id: QueueId,
    /// Name of the queue this wrapper refers to.
    pub queue_name: QueueName,
    // NOTE(review): presumably true when the recipients are spread over more
    // than one queue — confirm.
    pub is_multi_queue: bool,
    // NOTE(review): presumably a tracing span identifier — confirm.
    pub span_id: u64,
    /// The wrapped message.
    pub message: Message,
}

/// A quota entry, tagged by what it tracks. Both variants carry a raw key
/// and a numeric id.
// NOTE(review): `Size` presumably tracks bytes and `Count` a number of
// messages — confirm against the quota module.
#[derive(
    rkyv::Serialize,
    rkyv::Deserialize,
    rkyv::Archive,
    Debug,
    Clone,
    PartialEq,
    Eq,
    serde::Deserialize,
)]
pub enum QuotaKey {
    Size { key: Vec<u8>, id: u64 },
    Count { key: Vec<u8>, id: u64 },
}

/// Delivery state of a single recipient of a queued message.
#[derive(
    rkyv::Serialize,
    rkyv::Deserialize,
    rkyv::Archive,
    Debug,
    Clone,
    PartialEq,
    Eq,
    serde::Deserialize,
)]
pub struct Recipient {
    /// Recipient address; `Recipient::new` passes it through
    /// `to_lowercase_domain()` before storing it.
    pub address: String,

    /// Retry schedule; `inner` is exposed to expressions as the retry number.
    pub retry: Schedule<u32>,
    /// Notification schedule; `inner` is exposed to expressions as the
    /// notification number.
    pub notify: Schedule<u32>,
    /// Expiry policy: either a TTL or a number of attempts.
    pub expires: QueueExpiry,

    /// Name of the queue this recipient is assigned to.
    pub queue: QueueName,
    /// Current delivery status.
    pub status: Status<HostResponse<String>, ErrorDetails>,
    /// Bit flags.
    // NOTE(review): presumably the `RCPT_*` constants apply here — confirm.
    pub flags: u64,
    // NOTE(review): presumably the DSN original-recipient (ORCPT) value — confirm.
    pub orcpt: Option<String>,
}

// Message origin flags. `QueueEnvelope::resolve_variable` tests these against
// `Message::flags` to derive the source variable.
// NOTE(review): only bits 32 and above are defined here; the lower bits are
// presumably used by flags declared elsewhere — confirm.
pub const FROM_AUTHENTICATED: u64 = 1 << 32;
pub const FROM_UNAUTHENTICATED: u64 = 1 << 33;
pub const FROM_UNAUTHENTICATED_DMARC: u64 = 1 << 34;
pub const FROM_DSN: u64 = 1 << 35;
pub const FROM_REPORT: u64 = 1 << 36;
pub const FROM_AUTOGENERATED: u64 = 1 << 37;

// Recipient flags. These reuse bit positions 32 and 33 of the `FROM_*` set.
// NOTE(review): presumably they are only ever stored in `Recipient::flags`,
// never in `Message::flags` — confirm.
pub const RCPT_DSN_SENT: u64 = 1 << 32;
pub const RCPT_STATUS_CHANGED: u64 = 1 << 33;

/// Delivery status, generic over the success payload `T` and the failure
/// payload `E`. The serde names of the variants are `scheduled`, `completed`,
/// `temp_fail` and `perm_fail`.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    rkyv::Serialize,
    rkyv::Deserialize,
    rkyv::Archive,
    serde::Serialize,
    serde::Deserialize,
)]
pub enum Status<T, E> {
    /// No outcome recorded; this is the status set by `Recipient::new`.
    #[serde(rename = "scheduled")]
    Scheduled,
    /// Finished successfully, with the success payload.
    #[serde(rename = "completed")]
    Completed(T),
    /// Failed temporarily, with the failure payload.
    #[serde(rename = "temp_fail")]
    TemporaryFailure(E),
    /// Failed permanently, with the failure payload.
    #[serde(rename = "perm_fail")]
    PermanentFailure(E),
}

/// An SMTP response together with the host it is attributed to.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    rkyv::Serialize,
    rkyv::Deserialize,
    rkyv::Archive,
    serde::Deserialize,
)]
pub struct HostResponse<T> {
    /// Host name, in whatever representation `T` provides.
    pub hostname: T,
    /// The response itself.
    pub response: Response<String>,
}

/// Reason a delivery attempt failed. The `Display` impl in this file renders
/// each variant as a human-readable sentence, and
/// `QueueEnvelope::resolve_variable` maps each variant to a short keyword.
///
/// The `Default` value is `ConcurrencyLimited`.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    rkyv::Serialize,
    rkyv::Deserialize,
    rkyv::Archive,
    serde::Deserialize,
    Default,
)]
pub enum Error {
    /// DNS lookup failure, with details.
    DnsError(String),
    /// A command received a response that was not expected.
    UnexpectedResponse(UnexpectedResponse),
    /// Connection failure, with details.
    ConnectionError(String),
    /// TLS failure, with details.
    TlsError(String),
    /// DANE authentication failure, with details.
    DaneError(String),
    /// MTA-STS authentication failure, with details.
    MtaStsError(String),
    /// Delivery was rate limited.
    RateLimited,
    /// Too many concurrent connections to the remote server.
    #[default]
    ConcurrencyLimited,
    /// Queue-level error, with details.
    Io(String),
}

/// A command paired with the response it received; payload of
/// `Error::UnexpectedResponse`.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    rkyv::Serialize,
    rkyv::Deserialize,
    rkyv::Archive,
    serde::Deserialize,
)]
pub struct UnexpectedResponse {
    /// The command the response was received for.
    pub command: String,
    /// The response that was received.
    pub response: Response<String>,
}

/// An `Error` together with the entity it is reported for. Status messages
/// in this file are rendered as "... for {entity}: {details}".
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    rkyv::Serialize,
    rkyv::Deserialize,
    rkyv::Archive,
    Default,
    serde::Deserialize,
)]
pub struct ErrorDetails {
    /// What the error refers to.
    // NOTE(review): presumably a host or domain name — confirm.
    pub entity: String,
    /// The error itself.
    pub details: Error,
}

impl<T> Ord for Schedule<T> {
    /// Orders schedules by `due` in reverse: the entry with the smaller `due`
    /// compares as the greater one, so a max-heap such as
    /// `std::collections::BinaryHeap` yields the earliest entry first.
    /// `inner` is ignored.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.due.cmp(&other.due).reverse()
    }
}

impl<T> PartialOrd for Schedule<T> {
    /// Always returns `Some`: delegates to the total order defined by `Ord`.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}

impl<T> PartialEq for Schedule<T> {
    /// Two schedules are equal when their `due` values match; `inner` is not
    /// compared.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.due, other.due);
        lhs == rhs
    }
}

// Equality on `due` is a full equivalence relation, so `Eq` holds for any `T`.
impl<T> Eq for Schedule<T> {}

impl<T: Default> Schedule<T> {
    /// Creates a schedule that is due at the current time, as returned by
    /// `now()`, with a default `inner` value.
    pub fn now() -> Self {
        Schedule {
            due: now(),
            inner: T::default(),
        }
    }

    /// Creates a schedule that is due `duration` time units after the current
    /// time, with a default `inner` value.
    ///
    /// The addition saturates at `u64::MAX`, so an oversized `duration`
    /// yields a schedule that is effectively never due instead of
    /// overflowing (a panic with overflow checks, a wrap-around into the
    /// past without them).
    pub fn later(duration: u64) -> Self {
        Schedule {
            due: now().saturating_add(duration),
            inner: T::default(),
        }
    }
}

/// Borrowed view of one recipient of a message plus connection details, used
/// to resolve expression variables.
pub struct QueueEnvelope<'x> {
    /// The message the recipient belongs to.
    pub message: &'x Message,
    /// Domain part of the recipient address (set by `QueueEnvelope::new`).
    pub domain: &'x str,
    /// MX host name; `QueueEnvelope::new` leaves it empty.
    pub mx: &'x str,
    /// The recipient being evaluated.
    pub rcpt: &'x Recipient,
    /// Remote IP address; `QueueEnvelope::new` initialises it to `0.0.0.0`.
    pub remote_ip: IpAddr,
    /// Local IP address; `QueueEnvelope::new` initialises it to `0.0.0.0`.
    pub local_ip: IpAddr,
}

impl<'x> QueueEnvelope<'x> {
    pub fn new(message: &'x Message, rcpt: &'x Recipient) -> Self {
        Self {
            message,
            domain: rcpt.address.domain_part(),
            rcpt,
            mx: "",
            remote_ip: IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
            local_ip: IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
        }
    }
}

/// Resolves expression variables for one recipient of a queued message.
///
/// Unknown variable ids resolve to an empty string and every global lookup
/// yields `Variable::Integer(0)`.
impl<'x> ResolveVariable for QueueEnvelope<'x> {
    /// Maps the variable id `variable` (one of the `V_*` constants) to its
    /// current value for this envelope.
    fn resolve_variable(&self, variable: u32) -> expr::Variable<'x> {
        match variable {
            V_SENDER => self.message.return_path.as_str().into(),
            V_SENDER_DOMAIN => self.message.return_path.domain_part().into(),
            V_RECIPIENT_DOMAIN => self.domain.into(),
            V_RECIPIENT => self.rcpt.address.as_str().into(),
            V_RECIPIENTS => self
                .message
                .recipients
                .iter()
                .map(|r| Variable::from(r.address.as_str()))
                .collect::<Vec<_>>()
                .into(),
            V_QUEUE_RETRY_NUM => self.rcpt.retry.inner.into(),
            V_QUEUE_NOTIFY_NUM => self.rcpt.notify.inner.into(),
            // Remaining budget: time left for a TTL expiry, attempts left for
            // an attempt-based expiry. Both bottom out at zero.
            V_QUEUE_EXPIRES_IN => match &self.rcpt.expires {
                // Saturate the deadline so an oversized TTL cannot overflow
                // (which would panic with overflow checks, or wrap around and
                // report the recipient as already expired without them).
                QueueExpiry::Ttl(time) => time
                    .saturating_add(self.message.created)
                    .saturating_sub(now()),
                QueueExpiry::Attempts(count) => {
                    (count.saturating_sub(self.rcpt.retry.inner)) as u64
                }
            }
            .into(),
            V_QUEUE_LAST_STATUS => self.rcpt.status.to_compact_string().into(),
            // Short keyword for the last error; "none" when the recipient
            // has not failed.
            V_QUEUE_LAST_ERROR => match &self.rcpt.status {
                Status::Scheduled | Status::Completed(_) => "none",
                Status::TemporaryFailure(err) | Status::PermanentFailure(err) => {
                    match &err.details {
                        Error::DnsError(_) => "dns",
                        Error::UnexpectedResponse(_) => "unexpected-reply",
                        Error::ConnectionError(_) => "connection",
                        Error::TlsError(_) => "tls",
                        Error::DaneError(_) => "dane",
                        Error::MtaStsError(_) => "mta-sts",
                        Error::RateLimited => "rate",
                        Error::ConcurrencyLimited => "concurrency",
                        Error::Io(_) => "io",
                    }
                }
            }
            .into(),
            V_QUEUE_NAME => self.rcpt.queue.as_str().into(),
            V_QUEUE_AGE => now().saturating_sub(self.message.created).into(),
            // The first matching flag wins. The DMARC flag is tested before
            // the plain unauthenticated flag, so a message carrying both
            // reports "dmarc_pass".
            V_SOURCE => if (self.message.flags & FROM_AUTHENTICATED) != 0 {
                "authenticated"
            } else if (self.message.flags & FROM_UNAUTHENTICATED_DMARC) != 0 {
                "dmarc_pass"
            } else if (self.message.flags & FROM_UNAUTHENTICATED) != 0 {
                "unauthenticated"
            } else if (self.message.flags & FROM_DSN) != 0 {
                "dsn"
            } else if (self.message.flags & FROM_REPORT) != 0 {
                "report"
            } else if (self.message.flags & FROM_AUTOGENERATED) != 0 {
                "autogenerated"
            } else {
                "unknown"
            }
            .into(),
            V_MX => self.mx.into(),
            V_PRIORITY => self.message.priority.into(),
            V_REMOTE_IP => self.remote_ip.to_compact_string().into(),
            V_LOCAL_IP => self.local_ip.to_compact_string().into(),
            V_RECEIVED_FROM_IP => self.message.received_from_ip.to_compact_string().into(),
            V_RECEIVED_VIA_PORT => self.message.received_via_port.into(),
            V_SIZE => self.message.size.into(),
            _ => "".into(),
        }
    }

    /// Globals are not provided by this resolver; always returns
    /// `Variable::Integer(0)`.
    fn resolve_global(&self, _: &str) -> Variable<'_> {
        Variable::Integer(0)
    }
}

/// Resolves the subset of expression variables that can be answered from the
/// message alone: sender, sender domain, recipient list and priority.
impl ResolveVariable for Message {
    /// Returns the value of `variable`, or an empty string for ids this
    /// resolver does not handle.
    fn resolve_variable(&self, variable: u32) -> expr::Variable<'_> {
        match variable {
            V_SENDER => self.return_path.as_str().into(),
            V_SENDER_DOMAIN => self.return_path.domain_part().into(),
            V_RECIPIENTS => {
                // One entry per recipient, in message order.
                let mut addresses = Vec::with_capacity(self.recipients.len());
                for rcpt in &self.recipients {
                    addresses.push(Variable::from(rcpt.address.as_str()));
                }
                addresses.into()
            }
            V_PRIORITY => self.priority.into(),
            _ => "".into(),
        }
    }

    /// Globals are not provided by this resolver; always returns
    /// `Variable::Integer(0)`.
    fn resolve_global(&self, _name: &str) -> Variable<'_> {
        Variable::Integer(0)
    }
}

/// Newtype around a borrowed recipient domain, used to resolve the
/// recipient-domain expression variable on its own.
pub struct RecipientDomain<'x>(&'x str);

impl<'x> RecipientDomain<'x> {
    /// Wraps `domain` without copying or modifying it.
    pub fn new(domain: &'x str) -> Self {
        RecipientDomain(domain)
    }
}

impl<'x> ResolveVariable for RecipientDomain<'x> {
    fn resolve_variable(&self, variable: u32) -> expr::Variable<'x> {
        match variable {
            V_RECIPIENT_DOMAIN => self.0.into(),
            _ => "".into(),
        }
    }

    fn resolve_global(&self, _: &str) -> Variable<'_> {
        Variable::Integer(0)
    }
}

/// Converts the monotonic instant `time` into a UNIX timestamp in seconds,
/// using `now` as the reference point.
///
/// The result is the current wall-clock time in whole seconds since the UNIX
/// epoch plus the whole seconds by which `time` lies after `now`. If `time`
/// is not after `now` the offset is zero, and a wall clock set before the
/// epoch counts as zero.
#[inline(always)]
pub fn instant_to_timestamp(now: Instant, time: Instant) -> u64 {
    let wall_clock_secs = match SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    };
    let offset_secs = match time.checked_duration_since(now) {
        Some(ahead) => ahead.as_secs(),
        None => 0,
    };

    wall_clock_secs + offset_secs
}

impl Recipient {
    /// Creates a recipient for `address` (passed through
    /// `to_lowercase_domain()`).
    ///
    /// The new recipient is `Scheduled`, has no flags and no ORCPT, uses the
    /// default queue, expires after `Attempts(0)`, and has both its retry and
    /// notify schedules due now.
    pub fn new(address: impl AsRef<str>) -> Self {
        let address = address.to_lowercase_domain();
        let retry = Schedule::now();
        let notify = Schedule::now();

        Self {
            address,
            retry,
            notify,
            expires: QueueExpiry::Attempts(0),
            queue: QueueName::default(),
            status: Status::Scheduled,
            flags: 0,
            orcpt: None,
        }
    }

    /// Returns the recipient with its flags replaced by `flags`.
    pub fn with_flags(self, flags: u64) -> Self {
        Self { flags, ..self }
    }

    /// Returns the recipient with its ORCPT value replaced by `orcpt`.
    pub fn with_orcpt(self, orcpt: Option<String>) -> Self {
        Self { orcpt, ..self }
    }

    /// The recipient address.
    pub fn address(&self) -> &str {
        self.address.as_str()
    }

    /// The domain part of the recipient address.
    pub fn domain_part(&self) -> &str {
        self.address.domain_part()
    }
}

// Accessors mirroring `Recipient::address` and `Recipient::domain_part` for
// the rkyv-archived form of a recipient.
impl ArchivedRecipient {
    /// The recipient address as a string slice.
    pub fn address(&self) -> &str {
        self.address.as_str()
    }

    /// The domain part of the recipient address.
    pub fn domain_part(&self) -> &str {
        self.address.domain_part()
    }
}

/// Conversion from a UNIX timestamp in seconds to a monotonic [`Instant`].
pub trait InstantFromTimestamp {
    /// Returns the `Instant` corresponding to this timestamp.
    fn to_instant(&self) -> Instant;
}

impl InstantFromTimestamp for u64 {
    /// Interprets `self` as seconds since the UNIX epoch.
    ///
    /// A timestamp in the future maps to "now" plus the remaining seconds; a
    /// timestamp that is now or in the past maps to "now". If the target
    /// instant cannot be represented (for example a corrupted or absurdly
    /// large timestamp), the result is capped at roughly 30 years from now
    /// instead of panicking.
    fn to_instant(&self) -> Instant {
        // Cap used when the real offset does not fit into an `Instant`.
        const FAR_FUTURE: Duration = Duration::from_secs(86_400 * 365 * 30);

        let current_timestamp = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .map_or(0, |d| d.as_secs());
        let now = Instant::now();

        match self.checked_sub(current_timestamp) {
            // `Instant + Duration` panics on overflow, so use the checked
            // form and fall back to the cap.
            Some(ahead) if ahead > 0 => now
                .checked_add(Duration::from_secs(ahead))
                .or_else(|| now.checked_add(FAR_FUTURE))
                .unwrap_or(now),
            _ => now,
        }
    }
}

impl Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::UnexpectedResponse(response) => {
                write!(
                    f,
                    "Unexpected response for {}: {}",
                    response.command, response.response
                )
            }
            Error::DnsError(err) => {
                write!(f, "DNS lookup failed: {err}")
            }
            Error::ConnectionError(details) => {
                write!(f, "Connection failed: {details}",)
            }
            Error::TlsError(details) => {
                write!(f, "TLS error: {details}",)
            }
            Error::DaneError(details) => {
                write!(f, "DANE authentication failure: {details}",)
            }
            Error::MtaStsError(details) => {
                write!(f, "MTA-STS auth failed: {details}")
            }
            Error::RateLimited => {
                write!(f, "Rate limited")
            }
            Error::ConcurrencyLimited => {
                write!(f, "Too many concurrent connections to remote server")
            }
            Error::Io(err) => {
                write!(f, "Queue error: {err}")
            }
        }
    }
}

impl Display for ArchivedError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ArchivedError::UnexpectedResponse(response) => {
                write!(
                    f,
                    "Unexpected response for {}: {}",
                    response.command, response.response
                )
            }
            ArchivedError::DnsError(err) => {
                write!(f, "DNS lookup failed: {err}")
            }
            ArchivedError::ConnectionError(details) => {
                write!(f, "Connection failed: {details}",)
            }
            ArchivedError::TlsError(details) => {
                write!(f, "TLS error: {details}",)
            }
            ArchivedError::DaneError(details) => {
                write!(f, "DANE authentication failure: {details}",)
            }
            ArchivedError::MtaStsError(details) => {
                write!(f, "MTA-STS auth failed: {details}")
            }
            ArchivedError::RateLimited => {
                write!(f, "Rate limited")
            }
            ArchivedError::ConcurrencyLimited => {
                write!(f, "Too many concurrent connections to remote server")
            }
            ArchivedError::Io(err) => {
                write!(f, "Queue error: {err}")
            }
        }
    }
}

/// Human-readable rendering of a recipient's delivery status.
impl Display for Status<HostResponse<String>, ErrorDetails> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The two failure variants share one message shape and differ only
        // in the leading word.
        let (severity, failure) = match self {
            Status::Scheduled => return f.write_str("Scheduled"),
            Status::Completed(host) => return write!(f, "Delivered: {}", host.response),
            Status::TemporaryFailure(failure) => ("Temporary", failure),
            Status::PermanentFailure(failure) => ("Permanent", failure),
        };

        write!(
            f,
            "{severity} Failure for {}: {}",
            failure.entity, failure.details
        )
    }
}

/// Renders archived error details as "Error for {entity}: {details}".
impl Display for ArchivedErrorDetails {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            entity, details, ..
        } = self;
        write!(f, "Error for {entity}: {details}")
    }
}

/*

pub trait DisplayArchivedResponse {
    fn to_string(&self) -> String;
}

impl DisplayArchivedResponse for ArchivedResponse<String> {
    fn to_string(&self) -> String {
        format!(
            "Code: {}, Enhanced code: {}.{}.{}, Message: {}",
            self.code, self.esc[0], self.esc[1], self.esc[2], self.message,
        )
    }
}
*/
